#!/usr/bin/env python3
# Usage:
#  PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/output.npz
#  PYTHONPATH=src ./train --dataset /path/to/output.npz

import argparse
import numpy as np
import sys

from ftfy import fix_text

### inline this import
#import scrap_for_each_line
import tqdm
import sys
try:
  from tensorflow import io
  gfile = io.gfile
except:
  try:
    from tensorflow import gfile
  except ImportError:
    gfile = None
    from smart_open import open
import time

def file_size(f):
  """Return the size of ``f`` as reported by the underlying file object.

  ``f`` may be a path string (opened through ``try_open`` and measured
  recursively) or an already-open file-like object.  Objects whose type
  name contains ``gfile.GFile`` are asked for ``size()`` directly; anything
  else is measured by seeking to the end and reading ``tell()``, after which
  the original position is restored.
  """
  if isinstance(f, str):
    with try_open(f) as handle:
      return file_size(handle)

  # The GFile class is matched by name rather than isinstance, so this works
  # whether or not a gfile module could be imported.
  type_name = str(type(f))
  if 'gfile.GFile' in type_name:
    return f.size()

  saved_position = f.tell()
  try:
    f.seek(0, 2)  # whence=2: relative to end of stream
    return f.tell()
  finally:
    # Always put the cursor back, even if seeking/telling failed midway.
    f.seek(saved_position, 0)

def try_open(filename, *args, **kws):
  """Open ``filename``, using ``gfile.GFile`` for ``gs://`` paths when available.

  Args:
    filename: Path to open.  Paths starting with ``gs://`` are opened with
      the module-level ``gfile`` when it is not None.
    *args, **kws: Forwarded to ``GFile`` or ``open`` (e.g. the mode).

  Returns:
    The opened file object.  For the ``open`` fallback in text mode the
    encoding defaults to UTF-8 unless the caller passes ``encoding``.
  """
  # NOTE: ``gfile`` must never be assigned or imported inside this function.
  # Doing so makes the name function-local, and the check below would then
  # raise UnboundLocalError on every call.  The module-level binding is used.
  if filename.startswith("gs://") and gfile is not None:
    return gfile.GFile(filename, *args, **kws)

  # Binary modes reject an ``encoding`` argument, so only default it for text.
  mode = args[0] if args else kws.get('mode', 'r')
  is_binary = isinstance(mode, str) and 'b' in mode
  if not is_binary:
    kws.setdefault('encoding', 'utf-8')
  return open(filename, *args, **kws)


from contextlib import contextmanager


@contextmanager
def nullcontext():
  """Context manager that does nothing; ``as`` binds None inside the block."""
  yield None

import io

def for_each_line(f, total=None, verbose=True, ignore_errors=True, message=None, silent=False):
  """Yield ``(index, line)`` pairs from a path, a list, or a file-like object.

  Args:
    f: A path string (opened with ``try_open``), a list of lines, or an
      iterable file-like object.
    total: Accepted for interface compatibility; currently unused.
    verbose: When true, report each ``UnicodeDecodeError`` on stderr.
    ignore_errors: When true, resume iterating after a
      ``UnicodeDecodeError``; otherwise re-raise it.
    message: Accepted for interface compatibility; currently unused.
    silent: When true, no tqdm progress bar is created.

  Yields:
    ``(i, line)`` where ``i`` counts yielded lines from zero.
  """
  if isinstance(f, str):
    with try_open(f) as infile:
      # Forward every option, including ``silent``, to the recursive call.
      yield from for_each_line(infile, total=total, verbose=verbose,
                               ignore_errors=ignore_errors, message=message,
                               silent=silent)
    return

  if isinstance(f, list):
    yield from enumerate(f if silent else tqdm.tqdm(f))
    return

  try:
    size = file_size(f)
  except io.UnsupportedOperation:
    size = -1  # not seekable: size unknown

  # Show a bar whenever progress is wanted.  A known size becomes the bar's
  # total; an unknown size gives an open-ended counter (total=None).
  # The bar advances by len(line), so for multi-byte text it may stop short
  # of a byte-based total.
  if silent:
    progress = nullcontext()
  else:
    progress = tqdm.tqdm(total=size if size >= 0 else None)

  i = 0
  errors = 0
  prev = None
  # One bar spans the whole read, so a decode error does not reset progress.
  with progress as pbar:
    while True:
      try:
        for line in f:
          yield i, line
          i += 1
          if pbar is not None:
            pbar.update(len(line))
          prev = line
        break
      except UnicodeDecodeError:
        errors += 1
        if verbose:
          sys.stderr.write('Error on line %d after %s\n' % (i + errors + 1, repr(prev)))
        if not ignore_errors:
          raise

#### </import scrap_for_each_line>

# Command-line interface.  ``main`` opens ``infile`` as a single file, or reads
# stdin when it is "-", so the help text describes exactly that.
parser = argparse.ArgumentParser(
    description='Use FTFY to prepare a dataset for training.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('infile', metavar='PATH', type=str, help='Input file path (utf-8 text), or - for stdin')
parser.add_argument('--outfile', default="-", type=str, help='Output file path, or - for stdout')
parser.add_argument('--silent', action='store_true', help="Don't show progress bar")

def fix(line):
  """Return ``line`` cleaned by ``fix_text`` with ellipses spelled as ``...``.

  A line containing ``\\x04`` is split on that character, each piece is fixed
  on its own, and the pieces are rejoined with ``\\x04``.
  """
  pieces = line.split('\x04')
  if len(pieces) > 1:
    return '\x04'.join(map(fix, pieces))

  cleaned = fix_text(line)
  # ftfy does not replace the unicode ellipsis by default, so do it here.
  # NOTE: this departs from OpenAI's convention of calling ftfy.fix_text()
  # with default arguments; in particular, OpenAI's GPT-2 models do generate
  # unicode ellipses.  They are replaced with ... anyway to increase the
  # chances of semantic understanding.
  # A space followed by an ellipsis is handled first so the space is dropped
  # ("foo …" -> "foo..."); the second pass handles every remaining ellipsis
  # ("foo…" -> "foo...").
  return cleaned.replace(' …', '...').replace('…', '...')

def main():
  """Read the input line by line, fix each line, and write it to the output.

  Input comes from ``args.infile`` (stdin when it is ``-``) and output goes
  to ``args.outfile`` (stdout when it is ``-``).  Writing to stdout forces
  ``silent`` so no progress bar is requested.  Output is flushed every 100
  lines.
  """
  args = parser.parse_args()
  to_stdout = args.outfile == '-'
  if to_stdout:
    args.silent = True

  # Files are opened explicitly as UTF-8 (the encoding the CLI help states)
  # instead of the platform default.
  out = sys.stdout if to_stdout else open(args.outfile, "w", encoding="utf-8")
  try:
    with (sys.stdin if args.infile == '-' else open(args.infile, encoding="utf-8")) as f:
      for index, line in for_each_line(f, silent=args.silent):
        out.write(fix(line))
        # ``index`` is zero-based; flush after every 100th line written.
        if (index + 1) % 100 == 0:
          out.flush()
  finally:
    # Close (and thereby flush) an output file we opened ourselves.
    if not to_stdout:
      out.close()

# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
  main()

